
import functools
from abc import ABC, abstractmethod
from collections import deque, namedtuple

import weld.compile

class NodeId(object):
    """
    A node ID, which provides a unique name for a node in a DAG.

    Two IDs compare equal (and hash identically) when their names are equal.
    """
    __slots__ = ['name']

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Comparing against an unrelated type used to raise AttributeError
        # (e.g. `node_id == None`); defer to Python's default instead.
        if not isinstance(other, NodeId):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return "NodeId({0!r})".format(self.name)

class WeldNode(ABC):
    """
    Base class for nodes encapsulating a DAG of Weld computations.

    Nodes are compared and hashed by their `id`, which is generated lazily
    the first time it is accessed.
    """

    # Counters used by `generate_id`, keyed by `prefix()`. Keying by prefix
    # (instead of storing a counter attribute on each class) guarantees that
    # classes sharing a prefix -- e.g. a subclass that inherits its parent's
    # `prefix` -- also share a counter and can never produce the same ID.
    _id_counters = {}

    # ---------------------- Abstract Methods ------------------------------

    @property
    @abstractmethod
    def children(self):
        """ List of nodes this node depends on. """
        pass

    @property
    @abstractmethod
    def output_type(self):
        """ The Weld output type of this node. """
        pass

    @abstractmethod
    def evaluate(self):
        """
        Return a concrete result from this Weld computation.
        """
        pass


    # ---------------------- Provided Methods ------------------------------

    @classmethod
    def prefix(cls):
        """
        Prefix used for naming identifiers generated by this node.

        By default, this is the class name lowercased. This can be overridden.

        """
        return cls.__name__.lower()

    @classmethod
    def counter_(cls):
        """ Returns the last counter value used for this class's prefix (0 if unused). """
        return WeldNode._id_counters.get(cls.prefix(), 0)

    @classmethod
    def set_counter_(cls, value):
        """ Sets the counter value for this class's prefix. """
        WeldNode._id_counters[cls.prefix()] = value

    @classmethod
    def generate_id(cls):
        """ Generates a unique ID for this node. """
        cur_value = cls.counter_()
        cur_value += 1
        cls.set_counter_(cur_value)
        return NodeId("{0}{1}".format(cls.prefix(), cur_value))

    @property
    def id(self):
        """ This node's ID, generated on first access and then cached. """
        if not hasattr(self, "node_id_"):
            setattr(self, "node_id_", self.generate_id())
        return getattr(self, "node_id_")

    def _walk_bottomup(self, f, context, visited):
        """
        Depth-first, post-order DAG walk (children before parents).

        Implemented with an explicit stack rather than recursion so that long
        dependency chains do not hit Python's recursion limit. Nodes are
        marked visited when first reached, and children are explored in
        `children` order, matching a recursive post-order traversal.
        """
        if self in visited:
            return
        visited.add(self)
        # Each entry is (node, iterator over the children not yet explored).
        stack = [(self, iter(self.children))]
        while stack:
            node, pending = stack[-1]
            for dep in pending:
                if dep not in visited:
                    visited.add(dep)
                    stack.append((dep, iter(dep.children)))
                    break
            else:
                # All children of `node` are done: process it.
                stack.pop()
                f(node, context)

    def walk(self, f, context, mode="bottomup"):
        """ Walk the DAG in the specified order.

        Each node in the DAG is visited exactly once.

        Parameters
        ----------

        f : A function to apply to each node. The function takes a node and
        the context (i.e., any object) as arguments.

        context : An initial context, passed unchanged to every call of `f`.

        mode : The order in which to process the DAG. "bottomup" (the default)
        traverses the graph depth-first, so the roots are visited after the
        leaves (i.e., nodes are represented in "execution order" where
        dependencies are processed first). "topdown" traverses each node as
        it is visited in breadth-first order, starting from this node.

        """

        if mode == "bottomup":
            return self._walk_bottomup(f, context, set())

        assert mode == "topdown"

        visited = set()
        queue = deque([self])
        while len(queue) != 0:
            cur = queue.popleft()
            if cur not in visited:
                f(cur, context)
                visited.add(cur)
                for child in cur.children:
                    queue.append(child)

    def __eq__(self, other):
        # Comparing against a non-node used to raise AttributeError.
        if not isinstance(other, WeldNode):
            return NotImplemented
        return self.id == other.id

    def __hash__(self):
        return hash(self.id)


class PhysicalValue(WeldNode):
    """
    A physical value that a lazy computation depends on.

    Physical values are leaves of the DAG: they have no children, and
    `WeldLazy` passes them as arguments to its compiled program.
    """
    def __init__(self, value, ty, encoder):
        """
        Parameters
        ----------
        value : object
            The concrete input value.
        ty : WeldType
            The Weld type of `value`.
        encoder : An encoder for `value`; handed to `weld.compile.compile`
            together with the encoders of the other program inputs.
        """
        self.value = value
        self.encoder = encoder
        self.ty_ = ty

    @classmethod
    def prefix(cls):
        return "inp"

    @property
    def children(self):
        return []

    @property
    def output_type(self):
        return self.ty_

    def evaluate(self):
        # Previously returned the undefined name `value` (NameError).
        return self.value

class WeldLazy(WeldNode):
    """
    A lazy value that encapsulates a Weld computation.
    """

    def __init__(self, expression, dependencies, ty, decoder):
        """
        Creates a new lazy Weld computation.

        Parameters
        ----------
        expression : str
            A weld expression.
        dependencies : list[WeldNode]
            A list of dependencies. The expression should only use names
            from this list. Duplicates are removed; the first occurrence of
            each dependency keeps its position.
        ty : WeldType
            The output type of this computation.
        decoder : A decoder for decoding the Weld result of this computation.

        """
        # Remove duplicates while preserving order. `list(set(...))` made the
        # child order -- and therefore the walk order and the order of the
        # generated `let` statements -- depend on string hashes, which are
        # randomized per process by default.
        seen = set()
        self.children_ = []
        for dep in dependencies:
            if dep not in seen:
                seen.add(dep)
                self.children_.append(dep)
        self.expression = expression
        self.decoder = decoder
        self.ty_ = ty

        # Cache the compiled program.
        self.program_ = None
        # Cache the code
        self.code_ = None

    @property
    def children(self):
        return self.children_

    @property
    def num_dependencies(self):
        """
        Returns the total number of dependencies this computation relies on.
        This does not count `self` as a dependency.
        """
        def increment(node, count):
            count[0] += 1
        count = [0]
        self.walk(increment, count)
        return count[0] - 1

    @property
    def output_type(self):
        return self.ty_

    @property
    def code(self):
        """ The Weld source for this computation (assembled but not compiled). """
        self._assemble(compile=False)
        return self.code_

    @property
    def is_identity(self):
        """
        True if this computation just returns its single `PhysicalValue`
        dependency unchanged (as built by `identity`), so it can be evaluated
        without compiling any Weld code.

        The expression must be exactly the input's name: a node such as
        `WeldLazy("inp1 + 1", [inp1], ...)` also has a single physical child
        but is a real computation and must be compiled.
        """
        if len(self.children) != 1:
            return False
        child = self.children[0]
        return (isinstance(child, PhysicalValue)
                and self.expression.strip() == str(child.id))

    def _create_function_header(self, inputs):
        arguments = ["{0}: {1}".format(inp.id, str(inp.output_type)) for inp in inputs]
        return "|" + ", ".join(arguments) + "|"

    def _inputs(self, nodes_to_execute):
        """
        Returns an ordered list of inputs to this computation.

        """
        # Inputs are PhysicalValue objects.
        inputs = [node for node in nodes_to_execute if isinstance(node, PhysicalValue)]
        inputs.sort(key=lambda e: e.id.name)
        return inputs

    def _assemble(self, compile=False):
        """
        'Assembles' the program by constructing code from the children and optionally
        compiling it. Returns the program inputs.

        Parameters
        ----------
        compile : boolean
            If true, compiles the code.

        Post-conditions
        ---------------

        self.code_ is not None.
        if compile is True, self.program_ is not None.

        Returns
        -------
        list[inputs]

        """
        # Collect nodes in execution order.
        nodes_to_execute = []
        self.walk(lambda node, expressions: expressions.append(node), nodes_to_execute)
        inputs = self._inputs(nodes_to_execute)

        # Reuse cached results when they already satisfy the request.
        if self.code_ is not None and self.program_ is not None:
            return inputs
        if self.code_ is not None and not compile:
            return inputs

        arg_types = [inp.output_type for inp in inputs]
        encoders = [inp.encoder for inp in inputs]

        # Collect the expressions from the remaining nodes.
        expressions = [
            "let {name} = ({expr});".format(name=node.id, expr=node.expression)
            for node in nodes_to_execute if isinstance(node, WeldLazy)]
        # Bottom-up walk order places the root (this node) last.
        assert nodes_to_execute[-1] is self
        expressions.append(str(self.id))

        self.code_ = self._create_function_header(inputs) + " " + "\n".join(expressions)
        if compile:
            self.program_ = weld.compile.compile(self.code_, arg_types, encoders, self.output_type, self.decoder)
        return inputs

    def evaluate(self):
        """
        Evaluate this computation by compiling and running the encapsulated Weld program.

        Returns whatever the compiled program returns. For identity
        computations, returns `(value, None)` without compiling.

        """
        if self.is_identity:
            # This is a physical value -- return an empty context. The physical value
            # should manage its own context if it's owned by Weld.
            return (self.children[0].value, None)

        inputs = self._assemble(compile=True)
        values = [inp.value for inp in inputs]
        return self.program_(*values)

def identity(phys_value, decoder):
    """
    Wrap a `PhysicalValue` in a `WeldLazy` whose expression is just the
    value's identifier.

    Parameters
    ----------
    phys_value : PhysicalValue
        The input to wrap; it becomes the sole dependency, and its output
        type becomes the result type.
    decoder : A decoder for the result of the returned computation.

    Returns
    -------
    WeldLazy
    """
    name = str(phys_value.id)
    return WeldLazy(name, [phys_value], phys_value.output_type, decoder)

class WeldFunc(object):
    """
    Callable result of invoking a function decorated with `weldfunc`.

    It holds the raw Weld code produced by the decorated function, plus the
    `WeldLazy` dependencies collected from its arguments. Read `.code` to get
    the code without building a `WeldLazy`. Call the object with an output
    type and a decoder to build one.

    """

    __slots__ = ("_code", "_dependencies")

    def __init__(self, code, dependencies):
        self._code, self._dependencies = code, dependencies

    @property
    def code(self):
        """ The raw Weld code string produced by the decorated function. """
        return self._code

    def __call__(self, weld_output_type, decoder):
        """ Build a `WeldLazy` from the stored code and dependencies. """
        code = self.code
        return WeldLazy(code, self._dependencies, weld_output_type, decoder)

def weldfunc(func):
    """
    An annotation for converting functions that return Weld strings into functions
    that return `WeldLazy`.

    A function annotated with `weldfunc` takes zero or more arguments and
    returns a Weld string. The annotation turns it into a function that
    returns a `WeldFunc`. A `WeldFunc` is a partial function that builds the
    `WeldLazy` for the computation. It takes two arguments: an output Weld
    type and a Weld decoder.

    The decorated function sets the dependencies of the resulting `WeldLazy`
    automatically. Every `WeldLazy` argument, positional or keyword, becomes
    a dependency. The annotated function receives `str(arg.id)` in its place.

    Examples
    --------
    >>> from weld.types import F64
    >>> sqrt = lambda a: "sqrt({})".format(a)
    >>> sqrt = weldfunc(sqrt)
    >>> v1 = sqrt("1")(F64(), None)
    >>> v1.expression
    'sqrt(1)'
    >>> v2 = sqrt(v1)(F64(), None) # v1 becomes a dependency of v2
    >>> v2.expression == "sqrt({})".format(v1.id)
    True
    >>> v1 in v2.children
    True

    """
    # functools.wraps keeps the wrapped function's name and docstring.
    @functools.wraps(func)
    def weldfunc_wrapper(*args, **kwargs):
        dependencies = []

        def unwrap(value):
            # Replace a WeldLazy with its identifier, recording it as a
            # dependency; any other value is passed through unchanged.
            if not isinstance(value, WeldLazy):
                return value
            dependencies.append(value)
            return str(value.id)

        positional = [unwrap(arg) for arg in args]
        keyword = {name: unwrap(val) for name, val in kwargs.items()}

        code = func(*positional, **keyword)
        return WeldFunc(code, dependencies)

    return weldfunc_wrapper
